// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "ttnn/deprecated/tt_dnn/kernels/dataflow/generate_mm_scaler.hpp"
#include "ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp"

void kernel_main() {
    uint32_t src_addr = get_arg_val<uint32_t>(0);
    uint32_t num_tiles = get_arg_val<uint32_t>(1);
    uint32_t start_id = get_arg_val<uint32_t>(2);
    uint32_t mask_w = get_arg_val<uint32_t>(3);
    constexpr auto src_args = TensorAccessorArgs<0>();
    constexpr uint32_t scaler = get_compile_time_arg_val(src_args.next_compile_time_args_offset());

    constexpr uint32_t cb_id_in2 = tt::CBIndex::c_2;
    generate_mm_scaler(cb_id_in2, scaler);

    constexpr uint32_t cb_id_mask_w = tt::CBIndex::c_3;
#ifdef DO_MASK_W
    generate_mask_w(cb_id_mask_w, mask_w);
#endif

    constexpr uint32_t cb_id_in0 = tt::CBIndex::c_0;

    // ublocks size defined in tiles
    constexpr uint32_t onetile = 1;
    uint32_t tile_bytes = get_tile_size(cb_id_in0);

    const auto s = TensorAccessor(src_args, src_addr, tile_bytes);

    // read a ublock of tiles from src to CB, and then push the ublock to unpacker
    for (uint32_t i = start_id; i < start_id + num_tiles; i++) {
        cb_reserve_back(cb_id_in0, onetile);
        uint32_t l1_write_addr = get_write_ptr(cb_id_in0);
        noc_async_read_tile(i, s, l1_write_addr);
        noc_async_read_barrier();
        cb_push_back(cb_id_in0, onetile);
    }
}
